Aplicación: Visualización de espacio latente de estrellas variables
Contenido
Aplicación: Visualización de espacio latente de estrellas variables¶
Modelo¶
En este ejemplo práctico utilizaremos un AutoEncoder Variacional (Variational Autoencoder, VAE) (Kingma et al., 2014) para reducir la dimensionalidad de un dataset de curvas de luz de estrellas variables.
El siguiente es un diagrama del modelo:
donde
\(x\) se refiere a los datos observados, en este caso las curvas de luz.
\(z\) se refiere a la variable latente de dimensión reducida que queremos inferir.
\(g_\phi\) es una red neuronal artificial que llamaremos Codificador.
\(f_\theta\) es una red neuronal artificial que llamaremos Decodificador.
Un VAE es un modelo probabilístico generativo donde se consideran los siguientes supuestos:
\(p(z) = \mathcal{N}(0, I)\), es decir la distribución a priori de \(z\) es normal estándar.
\(p(x|z) = \mathcal{N}(\hat \mu, \hat \sigma^2)\), una verosimilitud normal para \(x\).
Se utiliza el decodificador para modelar \(\hat \mu = f_\theta(z)\).
Lo que buscamos es inferir \(z\) a partir de \(x\), es decir el posterior de \(z\):
Como lo anterior es muy difícil de calcular lo reemplazamos por una aproximación variacional. En este caso el posterior variacional es una distribución normal multivariada con covarianza diagonal (sin correlaciones):
donde el codificador se utiliza para modelar \(\mu_i, \sigma_i = g_\phi(x_i)\) de forma amortizada. Además se utiliza el truco de reparametrización (segunda linea de la ecuación).
Implementación¶
Implementaremos un VAE para curvas de luz de estrellas periódicas con dos bandas utilizando en flax, considerando lo siguiente:
El codificador procesa cada banda por separado y luego las combina en un único espacio latente.
El codificador retorna la media y la desviación estándar de la variable latente.
La desviación estándar debe ser no-negativa.
El decodificador recibe la variable latente y genera las curvas de cada banda.
from typing import Sequence, Callable
import jax
import jax.numpy as jnp
import flax.linen as nn
class Encoder(nn.Module):
hidden_units: int
latent_dim: int
activation: Callable = nn.relu
@nn.compact
def __call__(self, x):
g_0 = self.activation(nn.Dense(self.hidden_units)(x[:, 0, :])) # g-band
g_1 = self.activation(nn.Dense(self.hidden_units)(x[:, 1, :])) # r-band
g = self.activation(nn.Dense(self.hidden_units*2)(jnp.concatenate([g_0, g_1], axis=1)))
g = self.activation(nn.Dense(self.hidden_units)(g))
z_loc = nn.Dense(self.latent_dim)(g)
z_scale = nn.softplus(nn.Dense(self.latent_dim)(g))
return z_loc, z_scale
class Decoder(nn.Module):
output_dim: int
hidden_units: int
activation: Callable = nn.relu
@nn.compact
def __call__(self, z):
f = self.activation(nn.Dense(self.hidden_units)(z))
f = self.activation(nn.Dense(self.hidden_units*2)(f))
f_loc0 = self.activation(nn.Dense(self.output_dim)(f))
f_loc1 = self.activation(nn.Dense(self.output_dim)(f))
x_loc0 = nn.Dense(self.output_dim)(f_loc0) # g-band
x_loc1 = nn.Dense(self.output_dim)(f_loc1) # r-band
return x_loc0, x_loc1
Por conveniencia implementaremos también un módulo que llame a los anteriores y realice el truco re-parametrización del espacio latente:
class VAE(nn.Module):
hidden_units: int
latent_dim: int
output_dim: int
@nn.compact
def __call__(self, x, z_rng_key):
z_loc, z_scale = Encoder(self.hidden_units, self.latent_dim)(x)
eps = random.normal(z_rng_key, z_scale.shape)
z = z_loc + eps*z_scale
x_loc0, x_loc1 = Decoder(self.output_dim, self.hidden_units)(z)
return x_loc0, x_loc1, z_loc, z_scale
Preparación de datos¶
A continuación se preparan los datos para entrenar el modelo.
El modelo se entrena sobre curvas en fase (dobladas) que han sido interpoladas a una grilla regular mediante suavizado con kernels.
Se aplica un reescalamiento de tipo MinMax a las curvas interpoladas.
Las curvas en fase se alinean para partir en el mínimo de magnitud (máximo brillo)
import sys
import numpy as np
from sklearn.preprocessing import LabelEncoder
sys.path.append('../src/')
from preprocessing import load_ztf_data, kernel_smoothing
lcs, periods, labels = load_ztf_data()
le = LabelEncoder()
labels_int = le.fit_transform(labels)
pha_interp = np.linspace(0, 1, num=40)
mag_interp = np.zeros(shape=(len(lcs), 2, len(pha_interp)))
err_interp = np.zeros(shape=(len(lcs), 2, len(pha_interp)))
for k, (lc, period) in enumerate(zip(lcs, periods)):
mag_interp[k], err_interp[k], _ = kernel_smoothing(lc, period, pha_interp)
mags_interp_jax = jnp.array(mag_interp)
errs_interp_jax = jnp.array(err_interp)
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Entrenamiento¶
Para entrenar el modelo utilizaremos el algoritmo de optimización AMSgrad, que es una extensión de gradiente descedente con tasa de aprendizaje adaptiva (Adam).
A continuación se muestra la función chain de la librería optax la cual permite implementar optimizadores customizados. En este caso el optimizador primero satura las gradientes con norma mayor a 10.0, luego aplica escala los gradients con las reglas adaptivas de AMSgrad y finalmente multiplica por la tasa de aprendizaje inicial:
import optax
optimizer = optax.chain(optax.clip_by_global_norm(10.0),
optax.scale_by_amsgrad(),
optax.scale(-1e-3))
Luego definimos la función de costo de VAE, el Evidence Lower Bound (ELBO) de la aproximación variacional. Matemáticamente esto se define como
donde, bajo los supuestos considerados, el término de la mano derecha tiene la siguiente solución analítica:
Para optimizar el VAE se busca maximizar el ELBO en función de los parámetros del codificador y decodificador.
Utilizaremos compilación JIT para acelerar el cálculo de los gradientes del ELBO
def negELBO(params, mag, err, key):
x_loc0, x_loc1, z_loc, z_scale = model.apply(params, mag, key)
d0 = 0.5*jnp.sum(jnp.square((x_loc0 - mag[:, 0, :])/err[:, 0, :]), axis=-1)
d1 = 0.5*jnp.sum(jnp.square((x_loc1 - mag[:, 1, :])/err[:, 1, :]), axis=-1)
kl_div = 0.5 * jnp.sum(-1 - 2*jnp.log(z_scale) + jnp.square(z_loc) + jnp.square(z_scale), axis=-1)
return jnp.sum(d0 + d1 + kl_div)
grad_loss_jit = jax.jit(jax.value_and_grad(negELBO, argnums=0))
Finalmente inicializamos el modelo y el optimizador y lanzamos la rutina de entrenamiento.
Se entrena por 300 épocas con minibatches de tamaño 32:
import holoviews as hv
hv.extension('bokeh')
import jax.random as random
from tqdm import tqdm
from train_utils import data_loader
key = random.PRNGKey(12345)
model = VAE(output_dim=40, hidden_units=100, latent_dim=2)
key, key_ = random.split(key)
params = model.init(key, jnp.zeros(shape=(1, 2, 40)), key_)
state = optimizer.init(params)
loss_history = []
for epoch in tqdm(range(300)):
loss_epoch = 0.0
key, key_ = random.split(key)
for bmag, berr in data_loader(key_, mags_interp_jax, errs_interp_jax,
batch_size=32, shuffle=True):
key, key_ = random.split(key)
loss_val, grads = grad_loss_jit(params, bmag, berr, key_)
loss_epoch += loss_val.item()
updates, state = optimizer.update(grads, state, params)
params = optax.apply_updates(params, updates)
loss_history.append(loss_epoch/len(lcs))
hv.Curve(loss_history, 'Epoch', 'negative ELBO').opts(width=500, logy=True)
100%|████████| 300/300 [19:41<00:00, 3.94s/it]
Evaluación y visualizaciones¶
Primero utilizamos el modelo para inferir las reconstrucciones y variables latentes del dataset completo:
key, key_ = random.split(key)
x_loc0, x_loc1, z_loc, z_scale = model.apply(params, mag_interp, key_)
A continuación se muestra ejemplos de distintos tipos de estrella variable. Las lineas corresponden a las reconstrucciones del modelo y las datos de entrada (curvas de luz interpoladas).
from plotting import plot_reconstruction, plot_latent_space, plot_latent_generation
hv.Layout([plot_reconstruction(x_loc0[idx], x_loc1[idx], pha_interp,
mag_interp[idx], err_interp[idx], labels[idx]) for idx in [0, 2000, 3550, 3385]]).cols(2)
Luego se visualiza el espacio latente bidimensional. Cada color representa un tipo de estrella variable.
Podemos notar que las clases tienden a separarse en el espacio latente. Notar que el entrenamiento fue no supervisado, la clase no se utilizó para ajustar el modelo.
plot_latent_space(z_loc, z_scale, labels_int, le)
VAE es un modelo generativo. Podemos aprovechar esta capacidad para interpolar en el espacio latente y visualizar las formas de curva de luz que el modelo aprendió a reconstruir.
from functools import partial
z1 = jnp.linspace(-1, 1, 9)
z2 = jnp.linspace(-2, 2, 9)
decoder = Decoder(model.output_dim, model.hidden_units)
generate_lc = partial(jax.jit(decoder.apply), {'params': params['params']['Decoder_0']})
plot_latent_generation(z1, z2, pha_interp, generate_lc)
La reducción no-lineal de dimensionalidad nos permite hacer visualizaciones y explorar un dataset cuando no se tienen etiquetas. También puede utilizarse en estrategias de aprendizaje activo (active learning) para etiquetar un dataset de forma eficiente.
El espacio latente puede utilizarse también como entrada a un clasificador si se cuenta con algunas etiquetas (semi-supervisado).
La capacidad generativa del modelo puede aprovecharse para hacer aumentación de datos y también para interpretar lo que ha aprendido el modelo durante el entrenamiento.
